{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reference\n",
    "- https://medium.com/@neotheicebird/webcam-based-image-processing-in-ipython-notebooks-47c75a022514"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:38.815479Z",
     "start_time": "2017-12-27T10:43:38.352970Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import signal\n",
    "from IPython import display\n",
    "\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from skimage.transform import resize\n",
    "from keras.models import load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:38.829958Z",
     "start_time": "2017-12-27T10:43:38.828130Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cascade_path = '../model/cv2/haarcascade_frontalface_alt2.xml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:44.388891Z",
     "start_time": "2017-12-27T10:43:39.027900Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_path = '../model/keras/model/facenet_keras.h5'\n",
    "model = load_model(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:44.419587Z",
     "start_time": "2017-12-27T10:43:44.404326Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prewhiten(x):\n",
    "    if x.ndim == 4:\n",
    "        axis = (1, 2, 3)\n",
    "        size = x[0].size\n",
    "    elif x.ndim == 3:\n",
    "        axis = (0, 1, 2)\n",
    "        size = x.size\n",
    "    else:\n",
    "        raise ValueError('Dimension should be 3 or 4')\n",
    "\n",
    "    mean = np.mean(x, axis=axis, keepdims=True)\n",
    "    std = np.std(x, axis=axis, keepdims=True)\n",
    "    std_adj = np.maximum(std, 1.0/np.sqrt(size))\n",
    "    y = (x - mean) / std_adj\n",
    "    return y\n",
    "\n",
    "def l2_normalize(x, axis=-1, epsilon=1e-10):\n",
    "    output = x / np.sqrt(np.maximum(np.sum(np.square(x), axis=axis, keepdims=True), epsilon))\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:44.447934Z",
     "start_time": "2017-12-27T10:43:44.439987Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def calc_embs(imgs, margin, batch_size):\n",
    "    aligned_images = prewhiten(imgs)\n",
    "    pd = []\n",
    "    for start in range(0, len(aligned_images), batch_size):\n",
    "        pd.append(model.predict_on_batch(aligned_images[start:start+batch_size]))\n",
    "    embs = l2_normalize(np.concatenate(pd))\n",
    "\n",
    "    return embs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:44.843761Z",
     "start_time": "2017-12-27T10:43:44.469101Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class FaceDemo(object):\n",
    "    def __init__(self, cascade_path):\n",
    "        self.vc = None\n",
    "        self.cascade = cv2.CascadeClassifier(cascade_path)\n",
    "        self.margin = 10\n",
    "        self.batch_size = 1\n",
    "        self.n_img_per_person = 10\n",
    "        self.is_interrupted = False\n",
    "        self.data = {}\n",
    "        self.le = None\n",
    "        self.clf = None\n",
    "        \n",
    "    def _signal_handler(self, signal, frame):\n",
    "        self.is_interrupted = True\n",
    "        \n",
    "    def capture_images(self, name='Unknown'):\n",
    "        vc = cv2.VideoCapture(0)\n",
    "        self.vc = vc\n",
    "        if vc.isOpened():\n",
    "            is_capturing, _ = vc.read()\n",
    "        else:\n",
    "            is_capturing = False\n",
    "\n",
    "        imgs = []\n",
    "        signal.signal(signal.SIGINT, self._signal_handler)\n",
    "        self.is_interrupted = False\n",
    "        while is_capturing:\n",
    "            is_capturing, frame = vc.read()\n",
    "            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "            faces = self.cascade.detectMultiScale(frame,\n",
    "                                         scaleFactor=1.1,\n",
    "                                         minNeighbors=3,\n",
    "                                         minSize=(100, 100))\n",
    "            if len(faces) != 0:\n",
    "                face = faces[0]\n",
    "                (x, y, w, h) = face\n",
    "                left = x - self.margin // 2\n",
    "                right = x + w + self.margin // 2\n",
    "                bottom = y - self.margin // 2\n",
    "                top = y + h + self.margin // 2\n",
    "                img = resize(frame[bottom:top, left:right, :],\n",
    "                             (160, 160), mode='reflect')\n",
    "                imgs.append(img)\n",
    "                cv2.rectangle(frame,\n",
    "                              (left-1, bottom-1),\n",
    "                              (right+1, top+1),\n",
    "                              (255, 0, 0), thickness=2)\n",
    "\n",
    "            plt.imshow(frame)\n",
    "            plt.title('{}/{}'.format(len(imgs), self.n_img_per_person))\n",
    "            plt.xticks([])\n",
    "            plt.yticks([])\n",
    "            display.clear_output(wait=True)\n",
    "            if len(imgs) == self.n_img_per_person:\n",
    "                vc.release()\n",
    "                self.data[name] = np.array(imgs)\n",
    "                break\n",
    "            try:\n",
    "                plt.pause(0.1)\n",
    "            except Exception:\n",
    "                pass\n",
    "            if self.is_interrupted:\n",
    "                vc.release()\n",
    "                break\n",
    "                \n",
    "    def train(self):\n",
    "        labels = []\n",
    "        embs = []\n",
    "        names = self.data.keys()\n",
    "        for name, imgs in self.data.items():\n",
    "            embs_ = calc_embs(imgs, self.margin, self.batch_size)    \n",
    "            labels.extend([name] * len(embs_))\n",
    "            embs.append(embs_)\n",
    "\n",
    "        embs = np.concatenate(embs)\n",
    "        le = LabelEncoder().fit(labels)\n",
    "        y = le.transform(labels)\n",
    "        clf = SVC(kernel='linear', probability=True).fit(embs, y)\n",
    "        \n",
    "        self.le = le\n",
    "        self.clf = clf\n",
    "        \n",
    "    def infer(self):\n",
    "        vc = cv2.VideoCapture(0)\n",
    "        self.vc = vc\n",
    "        if vc.isOpened():\n",
    "            is_capturing, _ = vc.read()\n",
    "        else:\n",
    "            is_capturing = False\n",
    "\n",
    "        signal.signal(signal.SIGINT, self._signal_handler)\n",
    "        self.is_interrupted = False\n",
    "        while is_capturing:\n",
    "            is_capturing, frame = vc.read()\n",
    "            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "            faces = self.cascade.detectMultiScale(frame,\n",
    "                                         scaleFactor=1.1,\n",
    "                                         minNeighbors=3,\n",
    "                                         minSize=(100, 100))\n",
    "            pred = None\n",
    "            if len(faces) != 0:\n",
    "                face = faces[0]\n",
    "                (x, y, w, h) = face\n",
    "                left = x - self.margin // 2\n",
    "                right = x + w + self.margin // 2\n",
    "                bottom = y - self.margin // 2\n",
    "                top = y + h + self.margin // 2\n",
    "                img = resize(frame[bottom:top, left:right, :],\n",
    "                             (160, 160), mode='reflect')\n",
    "                embs = calc_embs(img[np.newaxis], self.margin, 1)\n",
    "                pred = self.le.inverse_transform(self.clf.predict(embs))\n",
    "                cv2.rectangle(frame,\n",
    "                              (left-1, bottom-1),\n",
    "                              (right+1, top+1),\n",
    "                              (255, 0, 0), thickness=2)\n",
    "            plt.imshow(frame)\n",
    "            plt.title(pred)\n",
    "            plt.xticks([])\n",
    "            plt.yticks([])\n",
    "            display.clear_output(wait=True)\n",
    "            try:\n",
    "                plt.pause(0.1)\n",
    "            except Exception:\n",
    "                pass\n",
    "            if self.is_interrupted:\n",
    "                vc.release()\n",
    "                break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:43:44.873391Z",
     "start_time": "2017-12-27T10:43:44.861742Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f = FaceDemo(cascade_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:44:10.714335Z",
     "start_time": "2017-12-27T10:44:06.972864Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train with two or more people\n",
    "f.capture_images('nyoki-mtl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:44:17.284697Z",
     "start_time": "2017-12-27T10:44:15.933323Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:44:36.348569Z",
     "start_time": "2017-12-27T10:44:17.930893Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f.infer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
